import os
import random
import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sn
from tqdm import tqdm
from PIL import Image, ImageDraw
 

def plot_labels(labels, names=(), save_dir='./runs/'):
    # plot dataset labels
    print('Plotting labels... ')
    c, b = labels[:, 0], labels[:, 1:].transpose()  # classes, boxes
    nc = int(c.max() + 1)  # number of classes

 
    x = pd.DataFrame(b.transpose(), columns=['x', 'y', 'width', 'height'])
 
    # seaborn correlogram
    sn.pairplot(x, corner=True, diag_kind='auto', kind='hist', diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9))
    plt.savefig(os.path.join(save_dir, 'labels_correlogram.jpg'), dpi=200)
    plt.close()
 
    # matplotlib labels
    matplotlib.use('svg')  # faster
    ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel()
    y = ax[0].hist(c, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
    for i in range(len(y[0])):
    #  xy 即为在图上的坐标  text 为内容
        ax[0].annotate(text=int(y[0][i]), xy=(y[1][i]+0.3, y[0][i]))
    data = {'category':names, 'nums':y[0]}
    df0 = pd.DataFrame(data)
    df0.to_excel(os.path.join(save_dir, 'labels.xlsx'), na_rep=False)
 
    # [y[2].patches[i].set_color([x / 255 for x in colors(i)]) for i in range(nc)]  # update colors bug #3195
    ax[0].set_ylabel('instances')
    if 0 < len(names) < 30:
        ax[0].set_xticks(range(len(names)))
        ax[0].set_xticklabels(names, rotation=90, fontsize=10)
    else:
        ax[0].set_xlabel('classes')
 
    sn.histplot(x, x='x', y='y', ax=ax[2], bins=50, pmax=0.9)
    sn.histplot(x, x='width', y='height', ax=ax[3], bins=50, pmax=0.9)
 
    # rectangles
    labels[:, 1:3] = 0.5  # center
    labels[:, 1:] = xywh2xyxy(labels[:, 1:]) * 2000
    img = Image.fromarray(np.ones((2000, 2000, 3), dtype=np.uint8) * 255)
    for cls, *box in labels[:1000]:
        ImageDraw.Draw(img).rectangle(box, width=1, outline=colors[int(cls)])  # plot
    ax[1].imshow(img)
    ax[1].axis('off')
 
    for a in [0, 1, 2, 3]:
        for s in ['top', 'right', 'left', 'bottom']:
            ax[a].spines[s].set_visible(False)
 
    plt.savefig(os.path.join(save_dir, 'labels.jpg'), dpi=200)
    matplotlib.use('Agg')
    plt.close()
 
def xywh2xyxy(x):
    # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
    y = np.copy(x)
    y[:, 0] = x[:, 0] - x[:, 2] / 2  # top left x
    y[:, 1] = x[:, 1] - x[:, 3] / 2  # top left y
    y[:, 2] = x[:, 0] + x[:, 2] / 2  # bottom right x
    y[:, 3] = x[:, 1] + x[:, 3] / 2  # bottom right y
    return y
 
path = r'E:\baiduyun\fall\FallDataset\labels'
shapes = [] 
ids = []  
file_list = os.listdir(path)
with open(os.path.join(path, 'classes.txt'), 'r') as cf:
    category = [i.replace("\n","") for i in cf.read().strip().split()]
num_classes = len(category) 
print(num_classes, category)

colors = [(random.randint(0,255),random.randint(0,255),random.randint(0,255)) for _ in range(num_classes)] 
file_list.remove('classes.txt')

for file in tqdm(file_list):
    with open(os.path.join(path, file), 'r') as f:
        for l in f.readlines():
            l = l.replace("\n","")
            if len(l) != 0:
                line = l.split(' ')  # ['11' '0.724877' '0.309082' '0.073938' '0.086914']
                try:
                    ids.append([int(line[0])])
                    shapes.append(list(map(float, line[1:])))
                except:
                    print(l,len(l))
                    print(line)
 
shapes = np.array(shapes)
ids = np.array(ids)
lbs = np.hstack((ids, shapes))
 
# print(lbs)
plot_labels(labels=lbs, names=np.array(category))